import torch
import argparse
import kgdlg
import json
from torch import cuda
import progressbar
import kgdlg.utils.misc_utils as utils

def indices_lookup(indices,fields):
    """Map a sequence of target-vocabulary indices to a space-joined sentence.

    Args:
        indices: iterable of int token ids into the target vocabulary.
        fields: dict of torchtext-style fields; ``fields['tgt'].vocab.itos``
            maps an index to its surface token string.

    Returns:
        The detokenized sentence as a single space-separated string.
    """
    itos = fields['tgt'].vocab.itos
    return ' '.join(itos[idx] for idx in indices)


def batch_indices_lookup(batch_indices,fields):
    """Convert a batch of index sequences into detokenized sentences.

    Args:
        batch_indices: iterable of per-sentence index sequences (token ids
            into the target vocabulary).
        fields: dict of fields passed through to :func:`indices_lookup`.

    Returns:
        A list of space-joined sentence strings, one per input sequence.
    """
    # Idiomatic comprehension replaces the manual append loop.
    return [indices_lookup(sent_indices, fields) for sent_indices in batch_indices]



def inference_file(translator, 
                   data_iter,
                   test_out, fields,
                   ):
    """Run batch inference over `data_iter` and write one prediction per line.

    Args:
        translator: inferer exposing ``inference_batch(batch)``; its result's
            ``ret['predictions'][0]`` holds the batch of index sequences.
        data_iter: iterator yielding input batches.
        test_out: path of the output file (UTF-8, one sentence per line).
        fields: field dict used to map indices back to tokens.
    """
    print('start decoding ...')
    with open(test_out, 'w', encoding='utf8') as out_file:
        progress = progressbar.ProgressBar()
        for batch in progress(data_iter):
            result = translator.inference_batch(batch)
            sentences = batch_indices_lookup(result['predictions'][0], fields)
            out_file.writelines(sentence + '\n' for sentence in sentences)

def main():
    """CLI entry point: parse arguments, build the evaluation data iterator,
    load the model checkpoint, and decode the test set to ``-test_out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-test_data", type=str)
    parser.add_argument("-tgt_test", type=str)
    parser.add_argument("-test_out", type=str)
    parser.add_argument("-config", type=str)
    parser.add_argument("-model", type=str)
    parser.add_argument("-vocab", type=str)
    parser.add_argument("-cluster_dict", type=str)
    parser.add_argument('-gpuid', default=[], nargs='+', type=int)
    parser.add_argument("-beam_size", type=int)
    parser.add_argument("-decode_max_length", type=int)
    parser.add_argument("-topk", type=int)
    parser.add_argument("-errormo", type=int)
    parser.add_argument("-train_mode", default=999, type=int)

    args = parser.parse_args()
    opt = utils.load_hparams(args.config)
    opt.train_mode = args.train_mode
    # Dropout config is irrelevant at inference time; clear any training value.
    opt.drop_out_set = None

    use_cuda = False
    device = None
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
        device = torch.device('cuda', args.gpuid[0])
        use_cuda = True

    fields = kgdlg.IO.load_fields_from_vocab(
                torch.load(args.vocab))

    # Modes 2/4/6/7 additionally consume target-side text ("带有tgt" = "with tgt");
    # all other modes decode from source only.
    if args.errormo in (2, 4, 6, 7):
        print("带有tgt")
        dataset = kgdlg.IO.TgtInferDataset(
            data_path=args.tgt_test,
            fields=[('src', fields["src"]), ('tgt', fields["tgt"])])
    else:
        dataset = kgdlg.IO.InferDataset(
            data_path=args.test_data,
            fields=[('src', fields["src"])])

    # Iterator construction is identical for both dataset kinds; build it once
    # instead of duplicating it in each branch.
    test_data_iter = kgdlg.IO.OrderedIterator(
        dataset=dataset, device=device,
        batch_size=1, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    model = kgdlg.ModelConstructor.create_joint_model(opt, fields)

    print('Loading parameters ...')
    model.load_checkpoint(args.model)
    if use_cuda:
        model = model.cuda()

    translator = kgdlg.Inferer(model=model,
                               fields=fields,
                               beam_size=args.beam_size,
                               n_best=1,
                               max_length=args.decode_max_length,
                               global_scorer=None,
                               cuda=use_cuda,
                               mode=args.errormo)

    inference_file(translator, test_data_iter, args.test_out, fields)

if __name__ == '__main__':
    main()